/*
Copyright (c) 2022 Wojciech Nawrocki. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Wojciech Nawrocki
*/

-- This is a Penrose (https://penrose.cs.cmu.edu/) style file for commutative diagrams.

const {
    -- Grid geometry.
    -- Smallest size a grid row/column may take; each object adds an optimized,
    -- non-negative amount on top of this value.
    gridMinSize = 50
    -- Extra room added to the label sizes when computing how much vertical /
    -- horizontal space a cell needs between its two endpoints.
    gridPadVertical = 50
    gridPadHorizontal = 50

    -- Cell decoration.
    -- How far the center of a curved cell is pushed sideways from the midpoint
    -- of its straight segment.
    curveOffset = 20
    -- Scale factor applied to the sideways displacement of a left/right label.
    labelOffsetFrac = 0.3
}

-- Names of the layout stages. The `in <stage>` annotations on the optimized
-- variables, constraints and objectives in this file refer to these names.
-- NOTE(review): Penrose is expected to run the stages in the order listed
-- here -- confirm against the staged-layout documentation.
layout = [Grid, Arrows, EnsureOnCanvas]

forall Targettable x {
    -- `center` is the root position of a targettable: the endpoints of cells
    -- attached to it are computed relative to this point.
    -- HACK: The position is declared as two scalar unknowns rather than one
    -- vec2 unknown. This makes Penrose track dependencies per coordinate
    -- instead of per vector. With whole-vector tracking, an assignment such as
    --   u ← (u[0], v[1]); v ← (u[0], v[1])
    -- would be reported as circular.
    float x.centerX = ? in [Grid, EnsureOnCanvas]
    float x.centerY = ? in [Grid, EnsureOnCanvas]
    vec2 x.center = (x.centerX, x.centerY)
    -- `textBoxCenter` is where the text box is drawn. It equals `center` by
    -- default and is overridden when the label is shifted to one side.
    vec2 x.textBoxCenter = x.center

    shape x.textBox = Rectangle {
        -- Placeholder dimensions; the actual size is written in dynamically.
        height: 30
        width: 50
        center: x.textBoxCenter
        cornerRadius: 5
        fillColor: theme.tooltipBackground
        -- Border around the box.
        strokeColor: theme.tooltipBorder
        strokeWidth: 1
    }

    -- NOTE: No longer used; attachment points are found by the optimizer instead.
    -- Sides of a Targettable where the ends of attached cells would be anchored.
    -- +-T-+
    -- L C R
    -- +-B-+
    -- vec2 x.left = x.center - (x.textBox.width / 2, 0)
    -- vec2 x.right = x.center + (x.textBox.width / 2, 0)
    -- vec2 x.top = x.center + (0, x.textBox.height / 2)
    -- vec2 x.bottom = x.center - (0, x.textBox.height / 2)

    -- NOTE: Handy when debugging, but must not appear in the real output.
    -- shape x.text = Equation {
    --     center: x.textBox.center
    --     string: x.label
    --     fontSize: "11pt"
    --     ensureOnCanvas: false
    -- }
    -- x.textLayering = x.text above x.textBox
}

forall Object A {
    -- Objects are placed on a grid whose row and column sizes are optimized.
    -- Every object stores the width of the grid column to its right
    -- (`gridRight`) and the height of the grid row below it (`gridBelow`).
    -- The positioning predicates copy these values between neighbours; for
    -- example IsLeftHorizontal/IsRightHorizontal propagate `gridBelow`.
    -- Sizes flow rightwards and downwards, which makes the leftmost, uppermost
    -- object the "root" of the diagram.
    --
    -- Optimized variables are arbitrary reals, so a non-negative quantity is
    -- obtained by squaring one. The size of a row/column is then the fixed
    -- minimum `const.gridMinSize` plus that squared, optimized surplus.

    -- Column to the right of A.
    float preGridRight = ? in Grid
    -- Row below A.
    float preGridBelow = ? in Grid

    float extraRight = preGridRight * preGridRight
    float A.gridRight = const.gridMinSize + extraRight
    -- Keep the surplus over the minimum as small as the constraints allow.
    encourage minimal(extraRight) in Grid

    float extraBelow = preGridBelow * preGridBelow
    float A.gridBelow = const.gridMinSize + extraBelow
    encourage minimal(extraBelow) in Grid
}

forall Cell f; Targettable A; Targettable B
where f := MakeCell(A, B) {
    -- A cell is drawn along the segment joining the centers of its endpoints,
    -- trimmed at both ends so that the arrow leaves room for the text boxes.
    -- Each trim length is the square of an optimized scalar, hence non-negative.
    float preOffStart = ? in Arrows
    float preOffEnd = ? in Arrows
    float trimStart = preOffStart * preOffStart
    float trimEnd = preOffEnd * preOffEnd
    vec2 f.dir = normalize(B.center - A.center)
    vec2 f.start = A.center + f.dir * trimStart
    vec2 f.end = B.center - f.dir * trimEnd
    vec2 f.mid = (f.start + f.end) / 2

    shape f.shape = Line {
        start: f.start
        end: f.end
        strokeColor: theme.foreground
        endArrowhead: "line"
        ensureOnCanvas: false
    }

    -- The arrow should be as long as possible (trims as short as possible)
    -- while staying outside the text boxes of its endpoints.

    -- NOTE: This constraint is delicate. Each optimized variable is one scalar
    -- that slides an arrow end along the direction between the endpoints. When
    -- an end moves horizontally inside a wide rectangle, the signed distance
    -- is measured vertically and is constant at most positions, so its
    -- gradient is 0.

    -- With staged layout the `disjoint` constraints behave well in practice.
    ensure disjoint(f.shape, A.textBox) in Arrows
    ensure disjoint(f.shape, B.textBox) in Arrows
    -- Keep each arrow end close to the text box it is attached to.
    float gapA = signedDistance(A.textBox, f.start)
    float gapB = signedDistance(B.textBox, f.end)
    ensure lessThan(gapA, 4) in Arrows
    ensure lessThan(gapB, 4) in Arrows

    -- A cell's center is not optimized; it is defined to be the arrow midpoint.
    override f.centerX = f.mid[0]
    override f.centerY = f.mid[1]

    -- Horizontal and vertical room this cell requires: half of each endpoint's
    -- text box, the cell's own text box, and the configured padding.
    float f.spaceH = A.textBox.width / 2 + f.textBox.width + B.textBox.width / 2 + const.gridPadHorizontal
    float f.spaceV = A.textBox.height / 2 + f.textBox.height + B.textBox.height / 2 + const.gridPadVertical

    -- Arrows and their labels do not move during the Grid stage, so the
    -- on-canvas objective is switched off for them; otherwise the optimizer
    -- could not converge.
    override f.textBox.ensureOnCanvas = false

    -- NOTE: Only needed in the playground. The Lean widget positions the
    -- diagram programmatically once optimization has finished.
    -- ensure onCanvas(f.textBox, canvas.width, canvas.height) in EnsureOnCanvas
    -- ensure onCanvas(f.shape, canvas.width, canvas.height) in EnsureOnCanvas

    -- Text boxes are drawn on top of the arrow.
    f.cellTextLayering = f.textBox above f.shape
    f.cellTextLayeringA = A.textBox above f.shape
    f.cellTextLayeringB = B.textBox above f.shape
}

-- Matches a cell `a` whose endpoints `f` and `g` are themselves cells.
-- This block adds no styling of its own yet.
forall Cell a; Cell f; Cell g
where a := MakeCell(f, g) {
    -- TODO: need a custom svg function to create nice double arrow for 2-cells
}

forall Cell f; Object A; Object B
where f := MakeCell(A, B); IsLeftHorizontal(f) {
    -- A and B share a grid row, with A one column to the right of B.
    -- Grid data propagates rightwards, so A is derived from B.
    override A.gridBelow = B.gridBelow
    override A.centerY = B.centerY
    override A.centerX = B.centerX + B.gridRight

    -- The column between them must be wide enough for the cell.
    ensure lessThan(f.spaceH, B.gridRight) in Grid
}

forall Cell f; Object A; Object B
where f := MakeCell(A, B); IsRightHorizontal(f) {
    -- A and B share a grid row, with B one column to the right of A.
    -- Grid data propagates rightwards, so B is derived from A.
    override B.gridBelow = A.gridBelow
    override B.centerY = A.centerY
    override B.centerX = A.centerX + A.gridRight

    -- The column between them must be wide enough for the cell.
    ensure lessThan(f.spaceH, A.gridRight) in Grid
}

forall Cell f; Object A; Object B
where f := MakeCell(A, B); IsUpVertical(f) {
    -- A and B share a grid column, with A one row below B.
    -- Grid data propagates downwards, so A is derived from B.
    override A.gridRight = B.gridRight
    override A.centerX = B.centerX
    override A.centerY = B.centerY - B.gridBelow

    -- The row between them must be tall enough for the cell.
    ensure lessThan(f.spaceV, B.gridBelow) in Grid
}

forall Cell f; Object A; Object B
where f := MakeCell(A, B); IsDownVertical(f) {
    -- A and B share a grid column, with B one row below A.
    -- Grid data propagates downwards, so B is derived from A.
    override B.gridRight = A.gridRight
    override B.centerX = A.centerX
    override B.centerY = A.centerY - A.gridBelow

    -- The row between them must be tall enough for the cell.
    ensure lessThan(f.spaceV, A.gridBelow) in Grid
}

forall Cell f; Object A; Object B
where f := MakeCell(A, B); IsLeftUpDiagonal(f) {
    -- A sits one column to the right of B and one row below it, so both
    -- coordinates of A are derived from B.
    override A.centerX = B.centerX + B.gridRight
    override A.centerY = B.centerY - B.gridBelow

    -- The column and row between them must be large enough for the cell.
    -- These are hard constraints (`ensure`), as in every other placement
    -- rule: `lessThan` is a constraint, not an objective, and a soft version
    -- would let labels overlap on this diagonal only.
    ensure lessThan(f.spaceH, B.gridRight) in Grid
    ensure lessThan(f.spaceV, B.gridBelow) in Grid
}

forall Cell f; Object A; Object B
where f := MakeCell(A, B); IsLeftDownDiagonal(f) {
    -- By convention grid constraints propagate rightwards and downwards, so
    -- the two equalities point in opposite directions here: the row of B is
    -- derived from A, while the column of A is derived from B.
    override B.centerY = A.centerY - A.gridBelow
    override A.centerX = B.centerX + B.gridRight

    -- The row below A and the column right of B must both fit the cell.
    ensure lessThan(f.spaceH, B.gridRight) in Grid
    ensure lessThan(f.spaceV, A.gridBelow) in Grid
}

forall Cell f; Object A; Object B
where f := MakeCell(A, B); IsRightUpDiagonal(f) {
    -- B sits one column to the right of A, and A one row below B. Following
    -- the rightwards/downwards convention, the column of B is derived from A
    -- and the row of A is derived from B.
    override A.centerY = B.centerY - B.gridBelow
    override B.centerX = A.centerX + A.gridRight

    -- The column right of A and the row below B must both fit the cell.
    ensure lessThan(f.spaceH, A.gridRight) in Grid
    ensure lessThan(f.spaceV, B.gridBelow) in Grid
}

forall Cell f; Object A; Object B
where f := MakeCell(A, B); IsRightDownDiagonal(f) {
    -- B sits one column to the right of A and one row below it, so both
    -- coordinates of B are derived from A.
    override B.centerY = A.centerY - A.gridBelow
    override B.centerX = A.centerX + A.gridRight

    -- The column right of A and the row below A must both fit the cell.
    ensure lessThan(f.spaceH, A.gridRight) in Grid
    ensure lessThan(f.spaceV, A.gridBelow) in Grid
}

forall Cell f; Targettable A; Targettable B
where f := MakeCell(A, B); IsCurvedLeft(f) {
    -- Move the center of the cell sideways from the arrow midpoint, along the
    -- arrow direction rotated by a quarter turn, by `const.curveOffset`.
    vec2 dirLeft = rot90(f.dir)
    vec2 newCenter = f.mid + dirLeft * const.curveOffset
    override f.centerX = newCenter[0]
    override f.centerY = newCenter[1]
    -- Replace the straight line with a quadratic curve through the shifted center.
    override f.shape = Path {
        d: interpolateQuadraticFromPoints("open", f.start, f.center, f.end)
        endArrowhead: "line"
        -- Use the theme's foreground color, as the straight arrow does, so
        -- that curved arrows stay visible whatever the theme is. The previous
        -- hard-coded `rgba(0,0,0,100)` was always black and passed an alpha
        -- outside the [0,1] range.
        strokeColor: theme.foreground
        ensureOnCanvas: false
    }
}

forall Cell f; Targettable A; Targettable B
where f := MakeCell(A, B); IsCurvedRight(f) {
    -- Three quarter turns give the direction opposite to a single `rot90`,
    -- i.e. the other side of the arrow.
    vec2 dirRight = rot90(rot90(rot90(f.dir)))
    -- Move the center of the cell sideways from the arrow midpoint by
    -- `const.curveOffset`.
    vec2 newCenter = f.mid + dirRight * const.curveOffset
    override f.centerX = newCenter[0]
    override f.centerY = newCenter[1]
    -- Replace the straight line with a quadratic curve through the shifted center.
    override f.shape = Path {
        d: interpolateQuadraticFromPoints("open", f.start, f.center, f.end)
        endArrowhead: "line"
        -- Use the theme's foreground color, as the straight arrow does, so
        -- that curved arrows stay visible whatever the theme is. The previous
        -- hard-coded `rgba(0,0,0,100)` was always black and passed an alpha
        -- outside the [0,1] range.
        strokeColor: theme.foreground
        ensureOnCanvas: false
    }
}

forall Cell f; Targettable A; Targettable B
where f := MakeCell(A, B); IsLabelLeft(f) {
    -- Sideways direction: the arrow direction rotated by a quarter turn.
    vec2 dirLeft = rot90(f.dir)
    -- How far the label box reaches along `dirLeft`, treating the box as an
    -- ellipse: scale the direction by the box width/height and take the norm.
    vec2 scaledDir = (dirLeft[0] * f.textBox.width, dirLeft[1] * f.textBox.height)
    float reach = norm(scaledDir)
    -- Shift the text box off the cell center by a fraction of that reach.
    vec2 shift = dirLeft * reach * const.labelOffsetFrac
    override f.textBoxCenter = f.center + shift
}

forall Cell f; Targettable A; Targettable B
where f := MakeCell(A, B); IsLabelRight(f) {
    -- Sideways direction: three quarter turns of the arrow direction, i.e.
    -- the side opposite to a single `rot90`.
    vec2 dirRight = rot90(rot90(rot90(f.dir)))
    -- How far the label box reaches along `dirRight`, treating the box as an
    -- ellipse: scale the direction by the box width/height and take the norm.
    vec2 scaledDir = (dirRight[0] * f.textBox.width, dirRight[1] * f.textBox.height)
    float reach = norm(scaledDir)
    -- Shift the text box off the cell center by a fraction of that reach.
    vec2 shift = dirRight * reach * const.labelOffsetFrac
    override f.textBoxCenter = f.center + shift
}

-- Dashed cells: draw the cell's shape with a dashed stroke.
forall Cell f; Targettable A; Targettable B
where f := MakeCell(A, B); IsDashed(f) {
    override f.shape.strokeStyle = "dashed"
}

-- Mono cells: add a "line" arrowhead at the start of the cell and flip it.
-- NOTE(review): presumably this renders the usual tail of a monomorphism
-- arrow -- confirm visually.
forall Cell f; Targettable A; Targettable B
where f := MakeCell(A, B); IsMono(f) {
    override f.shape.startArrowhead = "line"
    override f.shape.flipStartArrowhead = true
}

-- Epi cells: replace the end arrowhead with the "doubleLine" arrowhead.
forall Cell f; Targettable A; Targettable B
where f := MakeCell(A, B); IsEpi(f) {
    override f.shape.endArrowhead = "doubleLine"
}

-- Embedding cells: put the "loopup" arrowhead at the start of the cell.
-- NOTE(review): presumably this renders the hook of an embedding arrow --
-- confirm visually.
forall Cell f; Targettable A; Targettable B
where f := MakeCell(A, B); IsEmbedding(f) {
    override f.shape.startArrowhead = "loopup"
}
